import json
import os
import csv
import functools
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from heads import get_matching_head 
from aadata import TextPairDataset 

def custom_collate_fn(batch):
    """Collate ``(token_ids, reason, label, sample_type)`` tuples into one batch.

    Token-id sequences are kept as a Python list of 1-D ``torch.long`` tensors, so
    they may differ in length. Labels are stacked into a single float tensor.
    Reasons and sample types are returned as plain lists.
    """
    id_column, reason_column, label_column, type_column = zip(*batch)
    id_tensors = list(map(lambda seq: torch.tensor(seq, dtype=torch.long), id_column))
    return (
        id_tensors,
        list(reason_column),
        torch.tensor(label_column, dtype=torch.float),
        list(type_column),
    )


class MatchingInference:
    """GPU inference wrapper pairing a SentenceTransformer encoder with a trained matching head.

    On construction it:

    * loads ``{model_dir}/embedding_model`` (with ``trust_remote_code=True``) and the
      ``"cos_sim"`` matching head weights from ``{model_dir}/matching_head.pt``, moves both
      to CUDA, and puts them in eval mode;
    * pre-encodes every vocabulary token except tokens starting with ``"["``, blank tokens
      and the unk token, caching the normalized embeddings by token id in ``tokenid2emb``.

    Attributes:
        embedding_model: the SentenceTransformer encoder.
        matching_head: the matching head module.
        tokenid2emb: token id -> normalized embedding of that single token.
        vocab_inv: token id -> token string (full vocabulary).
        unk_token_id: the tokenizer's unk id, or None if it has no such attribute.
    """

    def __init__(self, model_dir):
        self.embedding_model = SentenceTransformer(f"{model_dir}/embedding_model", trust_remote_code=True)
        self.embedding_model = self.embedding_model.cuda()
        self.embedding_model.eval()

        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.matching_head = get_matching_head("cos_sim", embedding_dim)
        # Map to CPU first so a checkpoint saved on any device (e.g. a different GPU index)
        # restores cleanly; the head is moved to the GPU right below.
        state_dict = torch.load(f"{model_dir}/matching_head.pt", map_location="cpu")
        self.matching_head.load_state_dict(state_dict)
        self.matching_head = self.matching_head.cuda()
        self.matching_head.eval()

        tokenizer = self.embedding_model.tokenizer
        vocab = tokenizer.get_vocab()

        # Skip bracketed tokens (presumably special tokens such as "[CLS]"), blank tokens
        # and the unk token.
        filtered = [
            (tok, idx)
            for tok, idx in vocab.items()
            if not tok.startswith("[") and tok.strip() and tok != tokenizer.unk_token
        ]
        tokens = [tok for tok, _ in filtered]
        ids = [idx for _, idx in filtered]

        with torch.no_grad():
            if tokens:
                token_embs = self.embedding_model.encode(
                    tokens, convert_to_tensor=True, show_progress_bar=False, normalize_embeddings=True
                )
                self.tokenid2emb = {int(i): emb.cuda() for i, emb in zip(ids, token_embs)}
            else:
                self.tokenid2emb = {}

        self.vocab_inv = {v: k for k, v in vocab.items()}
        self.unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer, 'unk_token_id') else None

    def _zero_embedding(self):
        """Return an all-zero fallback embedding.

        It matches the cached token embeddings' shape, dtype and device when the cache is
        non-empty. Otherwise it is sized by the encoder's embedding dimension and placed
        on the encoder's device.
        """
        if self.tokenid2emb:
            return torch.zeros_like(next(iter(self.tokenid2emb.values())))
        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        return torch.zeros(embedding_dim, device=self.embedding_model.device)

    def get_token_embedding(self, token_ids, token_mode="mean"):
        """Embed a sequence of token ids as a single vector.

        Args:
            token_ids: 1-D tensor (any device) or plain sequence of integer token ids.
                The unk id is always ignored.
            token_mode: ``"mean"`` averages the cached per-token embeddings of the ids that
                are in the cache. ``"seq"`` maps the ids back to token strings, joins them
                with spaces and encodes that string as one sentence.

        Returns:
            A 1-D embedding tensor. It is an all-zero vector when no usable token remains.

        Raises:
            ValueError: if ``token_mode`` is neither ``"mean"`` nor ``"seq"``.
        """
        if token_mode not in ("mean", "seq"):
            raise ValueError(f"Unsupported token_mode: {token_mode}")

        # Convert to Python ints once. Per-token .item() calls would each force a
        # device->host sync when the ids live on the GPU.
        if torch.is_tensor(token_ids):
            id_list = token_ids.tolist()
        else:
            id_list = [int(i) for i in token_ids]
        unk = self.unk_token_id
        id_list = [i for i in id_list if unk is None or i != unk]

        if token_mode == "mean":
            embs = [self.tokenid2emb[i] for i in id_list if i in self.tokenid2emb]
            if not embs:
                return self._zero_embedding()
            # NOTE(review): the mean of unit vectors is not itself re-normalized.
            return torch.stack(embs).mean(dim=0)

        # token_mode == "seq"
        toks = [self.vocab_inv[i] for i in id_list if i in self.vocab_inv]
        toks = [tok for tok in toks if tok.strip()]
        if not toks:
            return self._zero_embedding()
        # NOTE(review): tokens are joined verbatim, so sub-word markers (e.g. WordPiece "##")
        # are not merged back, and this encode call does not pass normalize_embeddings.
        # Confirm both are intended.
        sentence = " ".join(toks)
        return self.embedding_model.encode(sentence, convert_to_tensor=True, show_progress_bar=False).squeeze(0)


@functools.lru_cache(maxsize=1)
def _load_inference_model(model_dir):
    """Return a MatchingInference for ``model_dir``, reusing the most recent one.

    Constructing MatchingInference loads both networks onto the GPU and encodes the whole
    tokenizer vocabulary. Consecutive evaluate() calls on the same model directory
    therefore share one instance. ``maxsize=1`` keeps at most one model alive.
    """
    return MatchingInference(model_dir)


def _score_batches(model, dataloader, token_mode, desc):
    """Run the matching head over every batch.

    Returns:
        A tuple ``(scores, labels, sample_types)`` of flat lists, where ``scores`` are
        sigmoid probabilities.
    """
    all_scores, all_labels, all_sample_types = [], [], []
    for token_id_lists, reasons, batch_labels, batch_sample_types in tqdm(dataloader, desc=desc):
        with torch.no_grad():
            emb_b = model.embedding_model.encode(reasons, convert_to_tensor=True, normalize_embeddings=True)
            # Token ids are only used as lookup keys, so they can stay on the CPU.
            emb_a = torch.stack(
                [model.get_token_embedding(token_ids, token_mode=token_mode) for token_ids in token_id_lists]
            ).to(emb_b.device)
            features = {"embedding_a": emb_a, "embedding_b": emb_b}
            logits = model.matching_head(features)["logits"]
            # reshape(-1) flattens both (B,) and (B, 1) logits to (B,). squeeze(-1) would turn a
            # single-sample (1,) output into a 0-d tensor whose tolist() is a bare float.
            scores = torch.sigmoid(logits).reshape(-1)

        all_scores.extend(scores.cpu().tolist())
        all_labels.extend(batch_labels.cpu().tolist())
        all_sample_types.extend(batch_sample_types)
    return all_scores, all_labels, all_sample_types


def _overall_metrics(labels, preds, scores):
    """Return ``(accuracy, f1, auc)`` over all samples.

    AUC falls back to 0.0 when only one label class is present and to -1.0 on any other
    ``ValueError`` from ``roc_auc_score``.
    """
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    try:
        auc = roc_auc_score(labels, scores)
    except ValueError:
        auc = 0.0 if len(set(labels)) < 2 else -1.0
    return acc, f1, auc


def _per_type_accuracy(labels, preds, sample_types, type_keys):
    """Return ``{"acc_<type>": accuracy}`` for each type in ``type_keys``, in order.

    ``"_tokens"`` is removed from type names when forming keys. Types with no samples map
    to ``"N/A"``. Samples whose type is not in ``type_keys`` are ignored.
    """
    buckets = {s_type: ([], []) for s_type in type_keys}
    for pred, true_label, s_type in zip(preds, labels, sample_types):
        if s_type in buckets:
            y_true, y_pred = buckets[s_type]
            y_true.append(int(true_label))
            y_pred.append(pred)

    accuracies = {}
    for s_type, (y_true, y_pred) in buckets.items():
        key = f"acc_{s_type.replace('_tokens', '')}"
        accuracies[key] = round(accuracy_score(y_true, y_pred), 4) if y_true else "N/A"
    return accuracies


def evaluate(model_dir, data_path, model_name, token_mode="mean", batch_size=128, *, threshold=0.5):
    """Score one data file with the model in ``model_dir`` and return a flat metrics dict.

    Args:
        model_dir: directory holding ``embedding_model/`` and ``matching_head.pt``.
        data_path: data file loaded through ``TextPairDataset``.
        model_name: currently unused. The dataset is built with the embedding model's own
            tokenizer and ``model_name=None``. Kept for call compatibility.
        token_mode: ``"mean"`` or ``"seq"``, forwarded to ``get_token_embedding``.
        batch_size: DataLoader batch size.
        threshold: sigmoid score at or above which a pair is predicted positive.

    Returns:
        A dict with ``model_dir``, ``data_path`` and ``token_mode``, plus overall accuracy,
        F1 and AUC rounded to 4 decimals. It also has one ``acc_<type>`` entry per sample
        type (``"positive"`` plus ``dataset.negative_types_keys``).
    """
    model = _load_inference_model(model_dir)

    dataset = TextPairDataset(data_path, tokenizer=model.embedding_model.tokenizer, model_name=None, negative_types=None)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    all_scores, all_labels, all_sample_types = _score_batches(
        model, dataloader, token_mode, desc=f"Evaluating {os.path.basename(data_path)}"
    )
    pred_labels = [1 if score >= threshold else 0 for score in all_scores]

    overall_acc, overall_f1, overall_auc = _overall_metrics(all_labels, pred_labels, all_scores)
    result_dict = {
        "model_dir": model_dir,
        "data_path": data_path,
        "token_mode": token_mode,
        "overall_accuracy": round(overall_acc, 4),
        "overall_f1_score": round(overall_f1, 4),
        "overall_auc_score": round(overall_auc, 4),
    }

    all_possible_types = ["positive"] + dataset.negative_types_keys
    result_dict.update(_per_type_accuracy(all_labels, pred_labels, all_sample_types, all_possible_types))
    return result_dict


def main_auto_eval(input_root, model_dir, output_dir, model_name=None, token_mode="mean"):
    """Evaluate every ``*.json`` file whose name contains ``"pos"`` one level below ``input_root``.

    Only immediate subdirectories of ``input_root`` are scanned; files placed directly in
    ``input_root`` are ignored. Each matching file is scored with ``evaluate``. A failure on
    one file is printed with its traceback and the sweep continues. All per-file result
    dicts are written as rows of ``{output_dir}/new_all_eval_results_{token_mode}.csv``.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []
    first_result_keys = None

    # Sorted so the CSV row order is reproducible (os.listdir order is arbitrary).
    for subdir in sorted(os.listdir(input_root)):
        sub_path = os.path.join(input_root, subdir)
        if not os.path.isdir(sub_path):
            continue

        for file in sorted(os.listdir(sub_path)):
            if not file.endswith(".json"):
                continue
            json_path = os.path.join(sub_path, file)
            print(f"INFO: Evaluating {json_path}...")
            if "pos" not in file:
                print("skip")
                continue

            try:
                result = evaluate(model_dir, json_path, model_name, token_mode=token_mode)
            except Exception as e:
                # Top-level sweep boundary: report the failure and move on to the next file.
                print(f"Failed to evaluate {json_path}: {e}")
                import traceback
                traceback.print_exc()
                continue

            results.append(result)
            if first_result_keys is None:
                first_result_keys = list(result.keys())
            elif set(first_result_keys) != set(result.keys()):
                print(f"WARNING: Key mismatch between results. Header from first result: {first_result_keys}. Current keys: {list(result.keys())}")

    if not results:
        print("No results to save.")
        return

    # Union of all keys: the first result's column order, then any keys that appear only in
    # later results, so no metric column is dropped. Missing cells are written as "".
    fieldnames = list(dict.fromkeys(key for res in results for key in res))

    csv_path = os.path.join(output_dir, f"new_all_eval_results_{token_mode}.csv")
    with open(csv_path, "w", newline='', encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        writer.writeheader()
        writer.writerows(results)

    print(f"\nAll results saved to {csv_path}")


if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse locations and options, then run the evaluation sweep.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_dir", type=str, required=True)
    cli.add_argument("--input_root", type=str, required=True, help="Root directory containing subfolders with json data")
    cli.add_argument("--output_dir", type=str, required=True)
    cli.add_argument("--token_mode", type=str, choices=["mean", "seq"], default="mean")
    cli.add_argument(
        "--model_name_for_st",
        type=str,
        default=None,
        help="SentenceTransformer model name if tokenizer needs to be loaded by dataset (deprecated if model_dir provides tokenizer)",
    )
    opts = cli.parse_args()

    main_auto_eval(
        input_root=opts.input_root,
        model_dir=opts.model_dir,
        output_dir=opts.output_dir,
        model_name=opts.model_name_for_st,
        token_mode=opts.token_mode,
    )